{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Implementing On-policy MC control\n",
    "\n",
    "Now, let's learn how to implement the MC control method with epsilon-greedy policy for playing the blackjack game, that is, we will see how can we use the MC control method for\n",
    "finding the optimal policy in the blackjack game:\n",
    "\n",
    "First, let's import the necessary libraries:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "import pandas as pd\n",
    "from collections import defaultdict\n",
    "import random"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create a blackjack environment:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "env = gym.make('Blackjack-v0')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Initialize the dictionary for storing the Q values:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "Q = defaultdict(float)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Initialize the dictionary for storing the total return of the state-action pair:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "total_return = defaultdict(float)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Initialize the dictionary for storing the count of the number of times a state-action pair is\n",
    "visited:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "N = defaultdict(int)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define the epsilon-greedy policy\n",
    "\n",
    "We learned that we select actions based on the epsilon-greedy policy, so we define a\n",
    "function called `epsilon_greedy_policy` which takes the state and Q value as an input\n",
    "and returns the action to be performed in the given state:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def epsilon_greedy_policy(state,Q):\n",
    "    \n",
    "    #set the epsilon value to 0.5\n",
    "    epsilon = 0.5\n",
    "    \n",
    "    #sample a random value from the uniform distribution, if the sampled value is less than\n",
    "    #epsilon then we select a random action else we select the best action which has maximum Q\n",
    "    #value as shown below\n",
    "    \n",
    "    if random.uniform(0,1) < epsilon:\n",
    "        return env.action_space.sample()\n",
    "    else:\n",
    "        return max(list(range(env.action_space.n)), key = lambda x: Q[(state,x)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generating an episode\n",
    "\n",
    "Now, let's generate an episode using the epsilon-greedy policy. We define a function called\n",
    "`generate_episode` which takes the Q value as an input and returns the episode.\n",
    "\n",
    "First, let's set the number of time steps:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_timesteps = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_episode(Q):\n",
    "    \n",
    "    #initialize a list for storing the episode\n",
    "    episode = []\n",
    "    \n",
    "    #initialize the state using the reset function\n",
    "    state = env.reset()\n",
    "    \n",
    "    #then for each time step\n",
    "    for t in range(num_timesteps):\n",
    "        \n",
    "        #select the action according to the epsilon-greedy policy\n",
    "        action = epsilon_greedy_policy(state,Q)\n",
    "        \n",
    "        #perform the selected action and store the next state information\n",
    "        next_state, reward, done, info = env.step(action)\n",
    "        \n",
    "        #store the state, action, reward in the episode list\n",
    "        episode.append((state, action, reward))\n",
    "        \n",
    "        #if the next state is a final state then break the loop else update the next state to the current\n",
    "        #state\n",
    "        if done:\n",
    "            break\n",
    "            \n",
    "        state = next_state\n",
    "\n",
    "    return episode"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Computing the optimal policy\n",
    "\n",
    "Now, let's learn how to compute the optimal policy. First, let's set the number of iterations, that is, the number of episodes, we want to generate:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_iterations = 50000"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We learned that in the on-policy control method, we will not be given any policy as an\n",
    "input. So, we initialize a random policy in the first iteration and improve the policy\n",
    "iteratively by computing Q value. Since we extract the policy from the Q function, we don't\n",
    "have to explicitly define the policy. As the Q value improves the policy also improves\n",
    "implicitly. That is, in the first iteration we generate episode by extracting the policy\n",
    "(epsilon-greedy) from the initialized Q function. Over a series of iterations, we will find the\n",
    "optimal Q function and hence we also find the optimal policy. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "#for each iteration\n",
    "for i in range(num_iterations):\n",
    "    \n",
    "    #so, here we pass our initialized Q function to generate an episode\n",
    "    episode = generate_episode(Q)\n",
    "    \n",
    "    #get all the state-action pairs in the episode\n",
    "    all_state_action_pairs = [(s, a) for (s,a,r) in episode]\n",
    "    \n",
    "    #store all the rewards obtained in the episode in the rewards list\n",
    "    rewards = [r for (s,a,r) in episode]\n",
    "\n",
    "    #for each step in the episode \n",
    "    for t, (state, action, reward) in enumerate(episode):\n",
    "\n",
    "        #if the state-action pair is occurring for the first time in the episode\n",
    "        if not (state, action) in all_state_action_pairs[0:t]:\n",
    "            \n",
    "            #compute the return R of the state-action pair as the sum of rewards\n",
    "            R = sum(rewards[t:])\n",
    "            \n",
    "            #update total return of the state-action pair\n",
    "            total_return[(state,action)] = total_return[(state,action)] + R\n",
    "            \n",
    "            #update the number of times the state-action pair is visited\n",
    "            N[(state, action)] += 1\n",
    "\n",
    "            #compute the Q value by just taking the average\n",
    "            Q[(state,action)] = total_return[(state, action)] / N[(state, action)]"
   ]
  },
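  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The averaging step in the loop above can also be written as an incremental update, which gives the same mean without keeping `total_return`. The cell below is a minimal sketch of that equivalence, not a change to the method used in this notebook: the helper name `incremental_update`, the `inc_Q` and `inc_N` dictionaries, and the sample returns are all made up for illustration:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "def incremental_update(Q, N, state_action, ret):\n",
    "    #count this visit, then move the estimate toward the new return by 1/N of the error\n",
    "    N[state_action] += 1\n",
    "    Q[state_action] += (ret - Q[state_action]) / N[state_action]\n",
    "\n",
    "#hypothetical returns observed for a single state-action pair\n",
    "sample_returns = [1.0, -1.0, -1.0, 0.0]\n",
    "pair = ((14, 10, False), 1)\n",
    "\n",
    "#separate dictionaries so the Q and N learned above are left untouched\n",
    "inc_Q, inc_N = defaultdict(float), defaultdict(int)\n",
    "for ret in sample_returns:\n",
    "    incremental_update(inc_Q, inc_N, pair, ret)\n",
    "\n",
    "#the plain average, as computed by total_return / N in the loop above\n",
    "batch_mean = sum(sample_returns) / len(sample_returns)\n",
    "print(inc_Q[pair], batch_mean)"
   ]
  },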
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Thus on every iteration, the Q value improves and so does policy.\n",
    "After all the iterations, we can have a look at the Q value of each state-action in the pandas\n",
    "data frame for more clarity.\n",
    "\n",
    "First, let's convert the Q value dictionary to a pandas data\n",
    "frame:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame(Q.items(),columns=['state_action pair','value'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's look at the first few rows of the data frame:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>state_action pair</th>\n",
       "      <th>value</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>((14, 10, False), 0)</td>\n",
       "      <td>-0.641944</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>((14, 10, False), 1)</td>\n",
       "      <td>-0.617698</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>((11, 10, False), 1)</td>\n",
       "      <td>-0.170015</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>((12, 3, False), 0)</td>\n",
       "      <td>-0.180328</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>((12, 3, False), 1)</td>\n",
       "      <td>-0.320388</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>((13, 1, False), 0)</td>\n",
       "      <td>-0.752381</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>((11, 6, False), 1)</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>((17, 6, False), 0)</td>\n",
       "      <td>-0.118644</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>((10, 9, False), 0)</td>\n",
       "      <td>-0.714286</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>((10, 9, False), 1)</td>\n",
       "      <td>-0.041322</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>((14, 4, False), 0)</td>\n",
       "      <td>-0.148289</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       state_action pair     value\n",
       "0   ((14, 10, False), 0) -0.641944\n",
       "1   ((14, 10, False), 1) -0.617698\n",
       "2   ((11, 10, False), 1) -0.170015\n",
       "3    ((12, 3, False), 0) -0.180328\n",
       "4    ((12, 3, False), 1) -0.320388\n",
       "5    ((13, 1, False), 0) -0.752381\n",
       "6    ((11, 6, False), 1)  0.000000\n",
       "7    ((17, 6, False), 0) -0.118644\n",
       "8    ((10, 9, False), 0) -0.714286\n",
       "9    ((10, 9, False), 1) -0.041322\n",
       "10   ((14, 4, False), 0) -0.148289"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.head(11)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we can observe, we have the Q values for all the state-action pairs. Now we can extract\n",
    "the policy by selecting the action which has maximum Q value in each state. "
   ]
  },
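  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The greedy extraction described above can be sketched as follows. This is a minimal illustration rather than the book's implementation: the helper name `extract_greedy_policy`, the `n_actions=2` default (stick or hit), and the `toy_Q` values are assumptions made up for the example. Pairs that were never visited are treated as having a Q value of 0, so an unvisited action can look better than a visited one with a negative value:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "def extract_greedy_policy(Q, n_actions=2):\n",
    "    #collect every state that appears in at least one (state, action) key\n",
    "    states = {state for (state, _) in Q}\n",
    "    \n",
    "    #for each state, pick the action with the largest Q value; Q.get treats\n",
    "    #unvisited pairs as 0 without adding new keys to Q\n",
    "    return {s: max(range(n_actions), key=lambda a: Q.get((s, a), 0.0)) for s in states}\n",
    "\n",
    "#toy Q values, made up for illustration (not taken from the run above)\n",
    "toy_Q = defaultdict(float, {\n",
    "    ((14, 10, False), 0): -0.64,\n",
    "    ((14, 10, False), 1): -0.62,\n",
    "    ((12, 3, False), 0): -0.18,\n",
    "    ((12, 3, False), 1): -0.32,\n",
    "})\n",
    "\n",
    "#maps each state to its greedy action (0 = stick, 1 = hit)\n",
    "toy_policy = extract_greedy_policy(toy_Q)\n",
    "print(toy_policy)"
   ]
  },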
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To learn more how to select action based on this Q value, check the book under the section, implementing on-policy control."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
